{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "51466c8d-8ce4-4b3d-be4e-18fdbeda5f53",
      "metadata": {},
      "source": [
        "# How to create a custom checkpointer using Redis\n",
        "\n",
        "When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions. Make sure that you have Redis running on port `6379` before going through this tutorial.\n",
        "\n",
        "This example shows how to use `Redis` as the backend for persisting checkpoint state.\n",
        "\n",
        "NOTE: this is just an example implementation. You can implement your own checkpointer using a different database or modify this one as long as it conforms to the `BaseCheckpointSaver` interface."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0aac2830",
      "metadata": {},
      "source": [
        "## Install the necessary libraries for Redis on Python"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "faadfb1b-cebe-4dcf-82fd-34044c380bc4",
      "metadata": {},
      "outputs": [],
      "source": [
        "%%capture --no-stderr\n",
        "%pip install -U redis langgraph langchain_openai"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a6a4e417",
      "metadata": {},
      "source": [
        "## Checkpointer implementation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "a35dba8e-5562-4803-ad80-160f53592dd7",
      "metadata": {},
      "outputs": [],
      "source": [
        "\"\"\"Implementation of a langgraph checkpoint saver using Redis.\"\"\"\n",
        "from contextlib import asynccontextmanager, contextmanager\n",
        "from typing import Any, AsyncGenerator, Generator, Union, Tuple, Optional\n",
        "\n",
        "import redis\n",
        "from redis.asyncio import Redis as AsyncRedis, ConnectionPool as AsyncConnectionPool\n",
        "from langchain_core.runnables import RunnableConfig\n",
        "from langgraph.checkpoint.base import BaseCheckpointSaver\n",
        "from langgraph.serde.jsonplus import JsonPlusSerializer\n",
        "from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, CheckpointTuple\n",
        "import logging\n",
        "\n",
        "# INFO level makes the checkpointer's store/retrieve log messages show up in the cell outputs.\n",
        "logging.basicConfig(level=logging.INFO)\n",
        "logger = logging.getLogger(__name__)\n",
        "\n",
        "\n",
        "class JsonAndBinarySerializer(JsonPlusSerializer):\n",
        "    \"\"\"JSON-plus serializer with extra handling for ``bytes``/``bytearray`` values.\"\"\"\n",
        "\n",
        "    def _default(self, obj: Any) -> Any:\n",
        "        \"\"\"Encode bytes-like objects as a ``fromhex`` constructor call; defer the rest to the parent.\"\"\"\n",
        "        if not isinstance(obj, (bytes, bytearray)):\n",
        "            return super()._default(obj)\n",
        "        return self._encode_constructor_args(\n",
        "            obj.__class__, method=\"fromhex\", args=[obj.hex()]\n",
        "        )\n",
        "\n",
        "    def dumps(self, obj: Any) -> str:\n",
        "        \"\"\"Serialize ``obj``; a top-level bytes-like value is returned as its hex string.\n",
        "\n",
        "        Any failure is logged and re-raised unchanged.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            is_raw_binary = isinstance(obj, (bytes, bytearray))\n",
        "            return obj.hex() if is_raw_binary else super().dumps(obj)\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Serialization error: {e}\")\n",
        "            raise\n",
        "\n",
        "    def loads(self, s: str, is_binary: bool = False) -> Any:\n",
        "        \"\"\"Deserialize ``s``; with ``is_binary`` it is read as a hex string and returned as bytes.\n",
        "\n",
        "        Any failure is logged and re-raised unchanged.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            return bytes.fromhex(s) if is_binary else super().loads(s)\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Deserialization error: {e}\")\n",
        "            raise\n",
        "\n",
        "\n",
        "def initialize_sync_pool(\n",
        "    host: str = \"localhost\", port: int = 6379, db: int = 0, **kwargs\n",
        ") -> redis.ConnectionPool:\n",
        "    \"\"\"Create a connection pool for the blocking Redis client.\n",
        "\n",
        "    Extra keyword arguments are forwarded to ``redis.ConnectionPool``.\n",
        "    Construction errors are logged and re-raised.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        sync_pool = redis.ConnectionPool(host=host, port=port, db=db, **kwargs)\n",
        "        logger.info(\n",
        "            f\"Synchronous Redis pool initialized with host={host}, port={port}, db={db}\"\n",
        "        )\n",
        "        return sync_pool\n",
        "    except Exception as e:\n",
        "        logger.error(f\"Error initializing sync pool: {e}\")\n",
        "        raise\n",
        "\n",
        "\n",
        "def initialize_async_pool(\n",
        "    url: str = \"redis://localhost\", **kwargs\n",
        ") -> AsyncConnectionPool:\n",
        "    \"\"\"Create a connection pool for the asyncio Redis client from a URL.\n",
        "\n",
        "    Extra keyword arguments are forwarded to ``AsyncConnectionPool.from_url``.\n",
        "    Construction errors are logged and re-raised.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        async_pool = AsyncConnectionPool.from_url(url, **kwargs)\n",
        "        logger.info(f\"Asynchronous Redis pool initialized with url={url}\")\n",
        "        return async_pool\n",
        "    except Exception as e:\n",
        "        logger.error(f\"Error initializing async pool: {e}\")\n",
        "        raise\n",
        "\n",
        "\n",
        "@contextmanager\n",
        "def _get_sync_connection(\n",
        "    connection: Union[redis.Redis, redis.ConnectionPool, None]\n",
        ") -> Generator[redis.Redis, None, None]:\n",
        "    \"\"\"Yield a blocking Redis client for ``connection``.\n",
        "\n",
        "    A client is yielded as-is and left open; for a pool, a temporary client\n",
        "    is created and closed on exit. Anything else raises ``ValueError``.\n",
        "    Redis connection errors raised while the context is active are logged\n",
        "    and re-raised.\n",
        "    \"\"\"\n",
        "    # Only a client created here is owned (and therefore closed) by this helper.\n",
        "    owned_client = None\n",
        "    try:\n",
        "        if isinstance(connection, redis.Redis):\n",
        "            client = connection\n",
        "        elif isinstance(connection, redis.ConnectionPool):\n",
        "            client = owned_client = redis.Redis(connection_pool=connection)\n",
        "        else:\n",
        "            raise ValueError(\"Invalid sync connection object.\")\n",
        "        yield client\n",
        "    except redis.ConnectionError as e:\n",
        "        logger.error(f\"Sync connection error: {e}\")\n",
        "        raise\n",
        "    finally:\n",
        "        if owned_client:\n",
        "            owned_client.close()\n",
        "\n",
        "\n",
        "@asynccontextmanager\n",
        "async def _get_async_connection(\n",
        "    connection: Union[AsyncRedis, AsyncConnectionPool, None]\n",
        ") -> AsyncGenerator[AsyncRedis, None]:\n",
        "    \"\"\"Yield an asyncio Redis client for ``connection``.\n",
        "\n",
        "    A client is yielded as-is and left open; for a pool, a temporary client\n",
        "    is created and closed with ``aclose()`` on exit. Anything else raises\n",
        "    ``ValueError``. Redis connection errors raised while the context is\n",
        "    active are logged and re-raised.\n",
        "    \"\"\"\n",
        "    # Only a client created here is owned (and therefore closed) by this helper.\n",
        "    owned_client = None\n",
        "    try:\n",
        "        if isinstance(connection, AsyncRedis):\n",
        "            client = connection\n",
        "        elif isinstance(connection, AsyncConnectionPool):\n",
        "            client = owned_client = AsyncRedis(connection_pool=connection)\n",
        "        else:\n",
        "            raise ValueError(\"Invalid async connection object.\")\n",
        "        yield client\n",
        "    except redis.ConnectionError as e:\n",
        "        logger.error(f\"Async connection error: {e}\")\n",
        "        raise\n",
        "    finally:\n",
        "        if owned_client:\n",
        "            await owned_client.aclose()\n",
        "\n",
        "\n",
        "class RedisSaver(BaseCheckpointSaver):\n",
        "    \"\"\"Checkpoint saver that persists LangGraph checkpoints in Redis hashes.\n",
        "\n",
        "    Every checkpoint lives under the key ``checkpoint:{thread_id}:{ts}`` as a\n",
        "    hash with the fields ``checkpoint``, ``metadata`` and ``parent_ts``. The\n",
        "    sync methods talk to Redis through ``sync_connection`` and the async\n",
        "    methods through ``async_connection``; each may be a client or a pool.\n",
        "    \"\"\"\n",
        "\n",
        "    sync_connection: Optional[Union[redis.Redis, redis.ConnectionPool]] = None\n",
        "    async_connection: Optional[Union[AsyncRedis, AsyncConnectionPool]] = None\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        sync_connection: Optional[Union[redis.Redis, redis.ConnectionPool]] = None,\n",
        "        async_connection: Optional[Union[AsyncRedis, AsyncConnectionPool]] = None,\n",
        "    ):\n",
        "        \"\"\"Store the connections and install ``JsonAndBinarySerializer`` as serializer.\"\"\"\n",
        "        super().__init__(serde=JsonAndBinarySerializer())\n",
        "        self.sync_connection = sync_connection\n",
        "        self.async_connection = async_connection\n",
        "\n",
        "    # ------------------------------------------------------------------\n",
        "    # Helpers shared by the sync and async code paths (no Redis I/O here)\n",
        "    # ------------------------------------------------------------------\n",
        "\n",
        "    @staticmethod\n",
        "    def _to_text(value: Any) -> str:\n",
        "        \"\"\"Return ``value`` as ``str``, decoding it first when Redis handed back bytes.\"\"\"\n",
        "        if isinstance(value, (bytes, bytearray)):\n",
        "            return value.decode()\n",
        "        return str(value)\n",
        "\n",
        "    @staticmethod\n",
        "    def _make_key(thread_id: Any, ts: Any) -> str:\n",
        "        \"\"\"Build the Redis key (or, with ``*`` parts, the key pattern) of a checkpoint.\"\"\"\n",
        "        return f\"checkpoint:{thread_id}:{ts}\"\n",
        "\n",
        "    @classmethod\n",
        "    def _parse_key(\n",
        "        cls, raw_key: Any, thread_id: Any = None\n",
        "    ) -> Optional[Tuple[Any, str]]:\n",
        "        \"\"\"Split a checkpoint key into ``(thread_id, ts)``; return None if it does not fit.\n",
        "\n",
        "        ISO timestamps contain ``:`` themselves, so the timestamp is everything\n",
        "        after the thread id, not the last ``:``-separated field. When\n",
        "        ``thread_id`` is None (listing across threads) the thread id is assumed\n",
        "        not to contain ``:``.\n",
        "        \"\"\"\n",
        "        key = cls._to_text(raw_key)\n",
        "        if thread_id is not None:\n",
        "            prefix = cls._make_key(thread_id, \"\")\n",
        "            # The prefix check also drops keys that only matched through glob\n",
        "            # characters contained in the thread id.\n",
        "            if not key.startswith(prefix):\n",
        "                return None\n",
        "            return thread_id, key[len(prefix):]\n",
        "        parts = key.split(\":\", 2)\n",
        "        if len(parts) != 3 or parts[0] != \"checkpoint\":\n",
        "            return None\n",
        "        return parts[1], parts[2]\n",
        "\n",
        "    @classmethod\n",
        "    def _select_keys(\n",
        "        cls,\n",
        "        raw_keys: Any,\n",
        "        thread_id: Any = None,\n",
        "        before: Optional[RunnableConfig] = None,\n",
        "    ) -> Any:\n",
        "        \"\"\"Return ``(key, thread_id, ts)`` entries, newest first.\n",
        "\n",
        "        Keys that are not checkpoint keys are skipped. With ``before`` only\n",
        "        entries whose timestamp is strictly older than its ``thread_ts`` are kept.\n",
        "        \"\"\"\n",
        "        before_ts = before[\"configurable\"].get(\"thread_ts\") if before else None\n",
        "        entries = []\n",
        "        for raw_key in raw_keys:\n",
        "            parsed = cls._parse_key(raw_key, thread_id)\n",
        "            if parsed is None:\n",
        "                continue\n",
        "            key_thread_id, ts = parsed\n",
        "            if before_ts and not ts < before_ts:\n",
        "                continue\n",
        "            entries.append((cls._to_text(raw_key), key_thread_id, ts))\n",
        "        # Timestamps are compared as strings; this assumes uniformly formatted\n",
        "        # ISO-8601 values, which then sort chronologically.\n",
        "        entries.sort(key=lambda entry: entry[2], reverse=True)\n",
        "        return entries\n",
        "\n",
        "    @staticmethod\n",
        "    def _matches_filter(metadata: Any, wanted: Optional[dict]) -> bool:\n",
        "        \"\"\"Return True when every ``wanted`` item equals the same item in ``metadata``.\"\"\"\n",
        "        if not wanted:\n",
        "            return True\n",
        "        if not isinstance(metadata, dict):\n",
        "            return False\n",
        "        return all(metadata.get(name) == value for name, value in wanted.items())\n",
        "\n",
        "    def _latest_config(\n",
        "        self, config: RunnableConfig, raw_keys: Any\n",
        "    ) -> Optional[RunnableConfig]:\n",
        "        \"\"\"Return ``config`` extended with the newest ``thread_ts`` among ``raw_keys``, or None.\"\"\"\n",
        "        configurable = config[\"configurable\"]\n",
        "        candidates = self._select_keys(raw_keys, configurable[\"thread_id\"])\n",
        "        if not candidates:\n",
        "            return None\n",
        "        return {\n",
        "            **config,\n",
        "            \"configurable\": {**configurable, \"thread_ts\": candidates[0][2]},\n",
        "        }\n",
        "\n",
        "    def _encode(\n",
        "        self,\n",
        "        config: RunnableConfig,\n",
        "        checkpoint: Checkpoint,\n",
        "        metadata: CheckpointMetadata,\n",
        "    ) -> Tuple[str, dict, RunnableConfig]:\n",
        "        \"\"\"Return ``(key, hash mapping, resulting config)`` for storing ``checkpoint``.\"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"]\n",
        "        # The config's current thread_ts, if any, is recorded as this checkpoint's parent.\n",
        "        parent_ts = config[\"configurable\"].get(\"thread_ts\")\n",
        "        mapping = {\n",
        "            \"checkpoint\": self.serde.dumps(checkpoint),\n",
        "            \"metadata\": self.serde.dumps(metadata),\n",
        "            # \"No parent\" is stored as an empty string.\n",
        "            \"parent_ts\": parent_ts if parent_ts else \"\",\n",
        "        }\n",
        "        new_config = {\n",
        "            \"configurable\": {\n",
        "                \"thread_id\": thread_id,\n",
        "                \"thread_ts\": checkpoint[\"ts\"],\n",
        "            },\n",
        "        }\n",
        "        return self._make_key(thread_id, checkpoint[\"ts\"]), mapping, new_config\n",
        "\n",
        "    def _decode(self, config: RunnableConfig, data: Any) -> Optional[CheckpointTuple]:\n",
        "        \"\"\"Build a ``CheckpointTuple`` from a Redis hash, or None if fields are missing.\n",
        "\n",
        "        ``config`` must carry the checkpoint's own ``thread_id``; it is returned\n",
        "        as the tuple's config. Field names may be bytes or str.\n",
        "        \"\"\"\n",
        "        fields = {self._to_text(name): value for name, value in (data or {}).items()}\n",
        "        if \"checkpoint\" not in fields or \"metadata\" not in fields:\n",
        "            return None\n",
        "        parent_ts = self._to_text(fields.get(\"parent_ts\") or \"\")\n",
        "        parent_config = (\n",
        "            {\n",
        "                \"configurable\": {\n",
        "                    \"thread_id\": config[\"configurable\"][\"thread_id\"],\n",
        "                    \"thread_ts\": parent_ts,\n",
        "                }\n",
        "            }\n",
        "            if parent_ts\n",
        "            else None\n",
        "        )\n",
        "        return CheckpointTuple(\n",
        "            config=config,\n",
        "            checkpoint=self.serde.loads(self._to_text(fields[\"checkpoint\"])),\n",
        "            metadata=self.serde.loads(self._to_text(fields[\"metadata\"])),\n",
        "            parent_config=parent_config,\n",
        "        )\n",
        "\n",
        "    # ------------------------------------------------------------------\n",
        "    # Checkpoint saver interface\n",
        "    # ------------------------------------------------------------------\n",
        "\n",
        "    def put(\n",
        "        self,\n",
        "        config: RunnableConfig,\n",
        "        checkpoint: Checkpoint,\n",
        "        metadata: CheckpointMetadata,\n",
        "    ) -> RunnableConfig:\n",
        "        \"\"\"Store ``checkpoint`` and ``metadata`` and return the config pointing at them.\n",
        "\n",
        "        Errors are logged and re-raised.\n",
        "        \"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"]\n",
        "        try:\n",
        "            key, mapping, new_config = self._encode(config, checkpoint, metadata)\n",
        "            with _get_sync_connection(self.sync_connection) as conn:\n",
        "                conn.hset(key, mapping=mapping)\n",
        "                logger.info(\n",
        "                    f\"Checkpoint stored successfully for thread_id: {thread_id}, ts: {checkpoint['ts']}\"\n",
        "                )\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Failed to put checkpoint: {e}\")\n",
        "            raise\n",
        "        return new_config\n",
        "\n",
        "    async def aput(\n",
        "        self,\n",
        "        config: RunnableConfig,\n",
        "        checkpoint: Checkpoint,\n",
        "        metadata: CheckpointMetadata,\n",
        "    ) -> RunnableConfig:\n",
        "        \"\"\"Async variant of ``put``; uses ``async_connection``.\"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"]\n",
        "        try:\n",
        "            key, mapping, new_config = self._encode(config, checkpoint, metadata)\n",
        "            async with _get_async_connection(self.async_connection) as conn:\n",
        "                await conn.hset(key, mapping=mapping)\n",
        "                logger.info(\n",
        "                    f\"Checkpoint stored successfully for thread_id: {thread_id}, ts: {checkpoint['ts']}\"\n",
        "                )\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Failed to aput checkpoint: {e}\")\n",
        "            raise\n",
        "        return new_config\n",
        "\n",
        "    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:\n",
        "        \"\"\"Fetch the checkpoint named by ``config``, or the newest one of its thread.\n",
        "\n",
        "        When ``config`` has no ``thread_ts`` the newest checkpoint of the thread\n",
        "        is looked up and the returned tuple's config carries its ``thread_ts``.\n",
        "        Returns None when nothing is stored. Errors are logged and re-raised.\n",
        "        \"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"]\n",
        "        try:\n",
        "            with _get_sync_connection(self.sync_connection) as conn:\n",
        "                if not config[\"configurable\"].get(\"thread_ts\"):\n",
        "                    config = self._latest_config(\n",
        "                        config, conn.keys(self._make_key(thread_id, \"*\"))\n",
        "                    )\n",
        "                    if config is None:\n",
        "                        logger.info(f\"No checkpoints found for thread_id: {thread_id}\")\n",
        "                        return None\n",
        "                thread_ts = config[\"configurable\"][\"thread_ts\"]\n",
        "                key = self._make_key(thread_id, thread_ts)\n",
        "                saved = self._decode(config, conn.hgetall(key))\n",
        "                if saved is None:\n",
        "                    logger.info(f\"No valid checkpoint data found for key: {key}\")\n",
        "                    return None\n",
        "                logger.info(\n",
        "                    f\"Checkpoint retrieved successfully for thread_id: {thread_id}, ts: {thread_ts}\"\n",
        "                )\n",
        "                return saved\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Failed to get checkpoint tuple: {e}\")\n",
        "            raise\n",
        "\n",
        "    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:\n",
        "        \"\"\"Async variant of ``get_tuple``; uses ``async_connection``.\"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"]\n",
        "        try:\n",
        "            async with _get_async_connection(self.async_connection) as conn:\n",
        "                if not config[\"configurable\"].get(\"thread_ts\"):\n",
        "                    config = self._latest_config(\n",
        "                        config, await conn.keys(self._make_key(thread_id, \"*\"))\n",
        "                    )\n",
        "                    if config is None:\n",
        "                        logger.info(f\"No checkpoints found for thread_id: {thread_id}\")\n",
        "                        return None\n",
        "                thread_ts = config[\"configurable\"][\"thread_ts\"]\n",
        "                key = self._make_key(thread_id, thread_ts)\n",
        "                saved = self._decode(config, await conn.hgetall(key))\n",
        "                if saved is None:\n",
        "                    logger.info(f\"No valid checkpoint data found for key: {key}\")\n",
        "                    return None\n",
        "                logger.info(\n",
        "                    f\"Checkpoint retrieved successfully for thread_id: {thread_id}, ts: {thread_ts}\"\n",
        "                )\n",
        "                return saved\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Failed to get checkpoint tuple: {e}\")\n",
        "            raise\n",
        "\n",
        "    def list(\n",
        "        self,\n",
        "        config: Optional[RunnableConfig],\n",
        "        *,\n",
        "        filter: Optional[dict[str, Any]] = None,\n",
        "        before: Optional[RunnableConfig] = None,\n",
        "        limit: Optional[int] = None,\n",
        "    ) -> Generator[CheckpointTuple, None, None]:\n",
        "        \"\"\"Yield stored checkpoints, newest first.\n",
        "\n",
        "        Args:\n",
        "            config: Restricts the listing to its ``thread_id``; a falsy value\n",
        "                lists the checkpoints of all threads.\n",
        "            filter: Only checkpoints whose metadata contains these items are yielded.\n",
        "            before: Only checkpoints older than this config's ``thread_ts`` are yielded.\n",
        "            limit: Maximum number of checkpoints to yield; a falsy value means no limit.\n",
        "\n",
        "        Errors are logged and re-raised.\n",
        "        \"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"] if config else None\n",
        "        pattern = self._make_key(\"*\" if thread_id is None else thread_id, \"*\")\n",
        "        try:\n",
        "            with _get_sync_connection(self.sync_connection) as conn:\n",
        "                emitted = 0\n",
        "                entries = self._select_keys(conn.keys(pattern), thread_id, before)\n",
        "                for key, key_thread_id, thread_ts in entries:\n",
        "                    if limit and emitted >= limit:\n",
        "                        break\n",
        "                    item_config = {\n",
        "                        \"configurable\": {\n",
        "                            \"thread_id\": key_thread_id,\n",
        "                            \"thread_ts\": thread_ts,\n",
        "                        }\n",
        "                    }\n",
        "                    saved = self._decode(item_config, conn.hgetall(key))\n",
        "                    if saved is None or not self._matches_filter(saved.metadata, filter):\n",
        "                        continue\n",
        "                    emitted += 1\n",
        "                    yield saved\n",
        "                    logger.info(\n",
        "                        f\"Checkpoint listed for thread_id: {key_thread_id}, ts: {thread_ts}\"\n",
        "                    )\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Failed to list checkpoints: {e}\")\n",
        "            raise\n",
        "\n",
        "    async def alist(\n",
        "        self,\n",
        "        config: Optional[RunnableConfig],\n",
        "        *,\n",
        "        filter: Optional[dict[str, Any]] = None,\n",
        "        before: Optional[RunnableConfig] = None,\n",
        "        limit: Optional[int] = None,\n",
        "    ) -> AsyncGenerator[CheckpointTuple, None]:\n",
        "        \"\"\"Async variant of ``list``; same arguments, uses ``async_connection``.\"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"] if config else None\n",
        "        pattern = self._make_key(\"*\" if thread_id is None else thread_id, \"*\")\n",
        "        try:\n",
        "            async with _get_async_connection(self.async_connection) as conn:\n",
        "                emitted = 0\n",
        "                entries = self._select_keys(await conn.keys(pattern), thread_id, before)\n",
        "                for key, key_thread_id, thread_ts in entries:\n",
        "                    if limit and emitted >= limit:\n",
        "                        break\n",
        "                    item_config = {\n",
        "                        \"configurable\": {\n",
        "                            \"thread_id\": key_thread_id,\n",
        "                            \"thread_ts\": thread_ts,\n",
        "                        }\n",
        "                    }\n",
        "                    saved = self._decode(item_config, await conn.hgetall(key))\n",
        "                    if saved is None or not self._matches_filter(saved.metadata, filter):\n",
        "                        continue\n",
        "                    emitted += 1\n",
        "                    yield saved\n",
        "                    logger.info(\n",
        "                        f\"Checkpoint listed for thread_id: {key_thread_id}, ts: {thread_ts}\"\n",
        "                    )\n",
        "        except Exception as e:\n",
        "            logger.error(f\"Failed to list checkpoints: {e}\")\n",
        "            raise"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1d142495",
      "metadata": {},
      "source": [
        "## Using the checkpointer"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "456fa19c-93a5-4750-a410-f2d810b964ad",
      "metadata": {},
      "source": [
        "## Setup environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "eca9aafb-a155-407a-8036-682a2f1297d7",
      "metadata": {},
      "outputs": [],
      "source": [
        "import getpass\n",
        "import os\n",
        "\n",
        "\n",
        "def _set_env(var: str):\n",
        "    if not os.environ.get(var):\n",
        "        os.environ[var] = getpass.getpass(f\"{var}: \")\n",
        "\n",
        "\n",
        "# Prompts for the key only when OPENAI_API_KEY is not already set in the environment.\n",
        "_set_env(\"OPENAI_API_KEY\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e26b3204-cca2-414c-800e-7e09032445ae",
      "metadata": {},
      "source": [
        "## Setup model and tools for the graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "e5213193-5a7d-43e7-aeba-fe732bb1cd7a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import Literal\n",
        "from langchain_core.runnables import ConfigurableField\n",
        "from langchain_core.tools import tool\n",
        "from langchain_openai import ChatOpenAI\n",
        "from langgraph.prebuilt import create_react_agent\n",
        "\n",
        "\n",
        "@tool\n",
        "def get_weather(city: Literal[\"nyc\", \"sf\"]):\n",
        "    \"\"\"Use this to get weather information.\"\"\"\n",
        "    # Canned answers for the two supported cities; anything else is rejected.\n",
        "    if city == \"nyc\":\n",
        "        return \"It might be cloudy in nyc\"\n",
        "    if city == \"sf\":\n",
        "        return \"It's always sunny in sf\"\n",
        "    raise AssertionError(\"Unknown city\")\n",
        "\n",
        "\n",
        "# Tool list and chat model that the cells below hand to create_react_agent.\n",
        "tools = [get_weather]\n",
        "model = ChatOpenAI(model_name=\"gpt-4o\", temperature=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e9342c62-dbb4-40f6-9271-7393f1ca48c4",
      "metadata": {},
      "source": [
        "## Use sync connection"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e39fc712-9e1c-4831-9077-dd07b0c13594",
      "metadata": {},
      "source": [
        "### With a connection pool"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "a1710e2f",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Synchronous Redis pool initialized with host=172.25.0.4, port=6379, db=0\n"
          ]
        }
      ],
      "source": [
        "# NOTE: 172.25.0.4 is the Redis host this notebook was recorded against (see the log output);\n",
        "# replace it with the address of your own Redis server, e.g. \"localhost\".\n",
        "sync_pool = initialize_sync_pool(host=\"172.25.0.4\", port=6379, db=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "2657c1c4-d8a5-4fe3-8f77-95415a98ed6c",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Only a sync connection is supplied, so only the sync methods (put/get_tuple/list) can be used;\n",
        "# the async ones would raise ValueError(\"Invalid async connection object.\").\n",
        "checkpointer = RedisSaver(sync_connection=sync_pool)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "id": "6d388241-de57-4b4e-af7b-eb1081fb8f36",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 1, ts: None\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:48.417492+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:48.420714+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:49.458951+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:49.465101+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 1, ts: 2024-07-09T08:22:50.084141+00:00\n"
          ]
        }
      ],
      "source": [
        "graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)\n",
        "# thread_id becomes part of every Redis key (checkpoint:{thread_id}:{ts}), so it groups this run's checkpoints.\n",
        "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
        "res = graph.invoke({\"messages\": [(\"human\", \"what's the weather in sf\")]}, config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "a7e0e7ec-a675-470b-9270-e4bdc59d4a4d",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'messages': [HumanMessage(content=\"what's the weather in sf\", id='64df8e19-0b9f-47f7-928f-4db3255485aa'),\n",
              "  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_n2XQOZHfpXpaNaviJakmjo82', 'function': {'arguments': '{\"city\":\"sf\"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 57, 'total_tokens': 71}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9057607f-6fa7-452b-95c7-f8f9832cb343-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_n2XQOZHfpXpaNaviJakmjo82'}], usage_metadata={'input_tokens': 57, 'output_tokens': 14, 'total_tokens': 71}),\n",
              "  ToolMessage(content=\"It's always sunny in sf\", name='get_weather', id='444d80db-8230-440a-b0eb-46a3f4db1006', tool_call_id='call_n2XQOZHfpXpaNaviJakmjo82'),\n",
              "  AIMessage(content='The weather in San Francisco is currently sunny.', response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 84, 'total_tokens': 94}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d576307f90', 'finish_reason': 'stop', 'logprobs': None}, id='run-676a1ee4-7301-4405-93a4-87af27a92614-0', usage_metadata={'input_tokens': 84, 'output_tokens': 10, 'total_tokens': 94})]}"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Display the agent result; the recorded output is a dict with a 'messages' list.\n",
        "res"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "96efd8b2-97c9-4207-83b2-00131723a75a",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 1, ts: None\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "{'v': 1,\n",
              " 'ts': '2024-07-08T12:21:14.392158+00:00',\n",
              " 'id': '1ef3d249-1acf-60ed-bfff-248e42e4d9f5',\n",
              " 'channel_values': {'messages': [],\n",
              "  '__start__': {'messages': [['human', \"what's the weather in sf\"]]}},\n",
              " 'channel_versions': {'__start__': 1},\n",
              " 'versions_seen': {},\n",
              " 'pending_sends': []}"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# config has no thread_ts, so the saver takes its no-thread_ts branch and picks a checkpoint\n",
        "# from all keys stored for thread 1.\n",
        "checkpointer.get(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "967c95c7-e392-4819-bd71-f29e91c68df3",
      "metadata": {},
      "source": [
        "### With a connection"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "b7d3687b",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 2, ts: None\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 2, ts: 2024-07-09T08:22:50.132262+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 2, ts: 2024-07-09T08:22:50.135993+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 2, ts: 2024-07-09T08:22:50.875540+00:00\n",
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 2, ts: None\n"
          ]
        }
      ],
      "source": [
        "import redis\n",
        "\n",
        "# Initialize the Redis synchronous direct connection (a client, not a pool).\n",
        "# NOTE: 172.25.0.4 is the Redis host this notebook was recorded against; replace it with your own.\n",
        "sync_redis_direct = redis.Redis(host=\"172.25.0.4\", port=6379, db=0)\n",
        "\n",
        "# Initialize the RedisSaver with the synchronous direct connection\n",
        "checkpointer = RedisSaver(sync_connection=sync_redis_direct)\n",
        "\n",
        "graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)\n",
        "config = {\"configurable\": {\"thread_id\": \"2\"}}\n",
        "res = graph.invoke({\"messages\": [(\"human\", \"what's the weather in sf\")]}, config)\n",
        "\n",
        "# config has no thread_ts, so get_tuple picks a checkpoint from all keys stored for thread 2.\n",
        "checkpoint_tuple = checkpointer.get_tuple(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c0a47d3e-e588-48fc-a5d4-2145dff17e77",
      "metadata": {},
      "source": [
        "## Use async connection"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ee6b6cf7-d8f7-4777-a48d-93b5855fe681",
      "metadata": {},
      "source": [
        "### With a connection pool"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "20cea8b7-8f13-4dc7-a3c9-825040eb4c57",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Asynchronous Redis pool initialized with url=redis://172.25.0.4:6379/0\n"
          ]
        }
      ],
      "source": [
        "# Initialize an asynchronous Redis connection pool\n",
        "async_pool = initialize_async_pool(url=\"redis://172.25.0.4:6379/0\")\n",
        "\n",
        "checkpointer = RedisSaver(async_connection=async_pool)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "f889dce6-7ec1-4277-b8af-ace7811733fa",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 3, ts: None\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:50.949172+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:50.951824+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:51.698633+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:51.702156+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 3, ts: 2024-07-09T08:22:53.530983+00:00\n"
          ]
        }
      ],
      "source": [
        "graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)\n",
        "config = {\"configurable\": {\"thread_id\": \"3\"}}\n",
        "res = await graph.ainvoke(\n",
        "    {\"messages\": [(\"human\", \"what's the weather in nyc\")]}, config\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "ed58c722-1662-4ae2-9bb7-4872158a5b29",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 3, ts: None\n"
          ]
        }
      ],
      "source": [
        "checkpoint_tuple = await checkpointer.aget_tuple(config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "e0c42044-4de6-4742-8e00-fe295d50c95a",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "CheckpointTuple(config={'configurable': {'thread_id': '3'}}, checkpoint={'v': 1, 'ts': '2024-07-08T12:21:18.866666+00:00', 'id': '1ef3d249-457b-62d3-bfff-b0e787336a7c', 'channel_values': {'messages': [], '__start__': {'messages': [['human', \"what's the weather in nyc\"]]}}, 'channel_versions': {'__start__': 1}, 'versions_seen': {}, 'pending_sends': []}, metadata={'source': 'input', 'step': -1, 'writes': {'messages': [['human', \"what's the weather in nyc\"]]}}, parent_config=None)"
            ]
          },
          "execution_count": 16,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "checkpoint_tuple"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "56552584-9eb8-40df-a6a0-44151018b509",
      "metadata": {},
      "source": [
        "### With a direct connection"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "id": "a7bf32bd",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:__main__:Checkpoint retrieved successfully for thread_id: 4, ts: None\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:53.585109+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:53.587207+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:54.932663+00:00\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:54.936425+00:00\n",
            "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
            "INFO:__main__:Checkpoint stored successfully for thread_id: 4, ts: 2024-07-09T08:22:55.982495+00:00\n"
          ]
        }
      ],
      "source": [
        "from redis.asyncio import Redis as AsyncRedis\n",
        "\n",
        "async with await AsyncRedis(host=\"172.25.0.4\", port=6379, db=0) as conn:\n",
        "    checkpointer = RedisSaver(async_connection=conn)\n",
        "    graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)\n",
        "    config = {\"configurable\": {\"thread_id\": \"4\"}}\n",
        "    res = await graph.ainvoke(\n",
        "        {\"messages\": [(\"human\", \"what's the weather in nyc\")]}, config\n",
        "    )\n",
        "    checkpoint_tuples = [c async for c in checkpointer.alist(config)]"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
